{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "50fa980f-1bae-40c1-a1f3-f5f89bef60d3",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Copyright 2023 Google LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5da5f038-057f-4475-a783-95660f98238c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Copyright 2023 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3a7c069-c441-4204-a905-59cbd9edc13a",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# MultiDiffusion with StyleAligned over SD v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14178de7-d4c8-4881-ac1d-ff84bae57c6f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler\n",
    "import mediapy\n",
    "import sa_handler\n",
    "import pipeline_calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "738cee0e-4d6e-4875-b4df-eadff6e27e7f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Model setup: load the SD v2 base checkpoint as a panorama pipeline with a\n",
    "# DDIM scheduler, then register the StyleAligned arguments on a handler\n",
    "# constructed from that pipeline.\n",
    "model_ckpt = \"stabilityai/stable-diffusion-2-base\"\n",
    "scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder=\"scheduler\")\n",
    "\n",
    "# Load the weights as float16 and move the pipeline to the CUDA device.\n",
    "pipeline = StableDiffusionPanoramaPipeline.from_pretrained(\n",
    "    model_ckpt,\n",
    "    scheduler=scheduler,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "pipeline = pipeline.to(\"cuda\")\n",
    "\n",
    "# StyleAligned flags: all three share_* options are on; the adain_* options\n",
    "# are on for queries and keys and off for values.\n",
    "sa_args = sa_handler.StyleAlignedArgs(\n",
    "    share_group_norm=True,\n",
    "    share_layer_norm=True,\n",
    "    share_attention=True,\n",
    "    adain_queries=True,\n",
    "    adain_keys=True,\n",
    "    adain_values=False,\n",
    ")\n",
    "\n",
    "# NOTE(review): register() presumably patches the pipeline in place, so this\n",
    "# cell should run exactly once per freshly loaded pipeline -- confirm in sa_handler.\n",
    "handler = sa_handler.Handler(pipeline)\n",
    "handler.register(sa_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea61e789-2814-4820-8ae7-234c3c6640a0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# run MultiDiffusion with StyleAligned\n",
    "\n",
    "reference_prompt = \"a beautiful papercut art design\"\n",
    "target_prompts = [\n",
    "    \"mountains in a beautiful papercut art design\",\n",
    "    \"giraffes in a beautiful papercut art design\",\n",
    "]\n",
    "view_batch_size = 25  # adjust according to VRAM size\n",
    "seed = 0  # change to sample a different reference latent\n",
    "\n",
    "# Seed torch's global RNG before drawing the reference latent so that re-running\n",
    "# this cell (or Restart & Run All) yields the same latent every time.\n",
    "# NOTE(review): any extra noise panorama_call draws from the global RNG would\n",
    "# also become repeatable -- confirm against pipeline_calls.\n",
    "torch.manual_seed(seed)\n",
    "reference_latent = torch.randn(1, 4, 64, 64)\n",
    "\n",
    "# The same reference prompt and reference latent are reused for every target\n",
    "# prompt; each call is given the pair [reference_prompt, target_prompt].\n",
    "for target_prompt in target_prompts:\n",
    "    images = pipeline_calls.panorama_call(\n",
    "        pipeline,\n",
    "        [reference_prompt, target_prompt],\n",
    "        reference_latent=reference_latent,\n",
    "        view_batch_size=view_batch_size,\n",
    "    )\n",
    "    mediapy.show_images(images, titles=[\"reference\", \"result\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "791a9b28-f0ce-4fd0-9f3c-594281c2ae56",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}